{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7df588cb",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "# The Dataset for Pretraining BERT\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b349c114",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:51.349012Z",
     "iopub.status.busy": "2023-08-18T19:27:51.348380Z",
     "iopub.status.idle": "2023-08-18T19:27:54.176466Z",
     "shell.execute_reply": "2023-08-18T19:27:54.175190Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import torch\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ead17d98",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The WikiText-2 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ea9beb6f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:54.181061Z",
     "iopub.status.busy": "2023-08-18T19:27:54.180516Z",
     "iopub.status.idle": "2023-08-18T19:27:54.187506Z",
     "shell.execute_reply": "2023-08-18T19:27:54.186224Z"
    },
    "origin_pos": 4,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "# Register the WikiText-2 archive under the key 'wikitext-2' in d2l.DATA_HUB\n",
    "# as a (download URL, 40-hex-digit string) pair; the second element is\n",
    "# presumably the archive's SHA-1 checksum -- confirm against d2l.\n",
    "d2l.DATA_HUB['wikitext-2'] = (\n",
    "    'https://s3.amazonaws.com/research.metamind.io/wikitext/'\n",
    "    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')\n",
    "\n",
    "def _read_wiki(data_dir):\n",
    "    \"\"\"Read the WikiText-2 training split as shuffled paragraphs.\n",
    "\n",
    "    Args:\n",
    "        data_dir: Directory containing the file 'wiki.train.tokens'.\n",
    "\n",
    "    Returns:\n",
    "        A list of paragraphs in random order. Each paragraph is the list of\n",
    "        strings produced by stripping and lowercasing one line of the file\n",
    "        and splitting it on ' . '.\n",
    "    \"\"\"\n",
    "    file_name = os.path.join(data_dir, 'wiki.train.tokens')\n",
    "    # Decode explicitly as UTF-8 instead of relying on the platform-dependent\n",
    "    # default encoding of open() (assumes the corpus file is UTF-8 encoded).\n",
    "    with open(file_name, 'r', encoding='utf-8') as f:\n",
    "        lines = f.readlines()\n",
    "    # Keep only lines whose raw text splits into at least two pieces on ' . '\n",
    "    paragraphs = [line.strip().lower().split(' . ')\n",
    "                  for line in lines if len(line.split(' . ')) >= 2]\n",
    "    random.shuffle(paragraphs)\n",
    "    return paragraphs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c937fed9",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Generating the Next Sentence Prediction Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8a515271",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:54.204406Z",
     "iopub.status.busy": "2023-08-18T19:27:54.204109Z",
     "iopub.status.idle": "2023-08-18T19:27:54.210105Z",
     "shell.execute_reply": "2023-08-18T19:27:54.209309Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def _get_next_sentence(sentence, next_sentence, paragraphs):\n",
    "    \"\"\"Build one next-sentence-prediction pair.\n",
    "\n",
    "    With probability 0.5 the given `next_sentence` is kept and the pair is\n",
    "    labelled True; otherwise it is replaced by a sentence drawn uniformly\n",
    "    from a uniformly drawn paragraph and the pair is labelled False.\n",
    "\n",
    "    Returns:\n",
    "        The tuple `(sentence, next_sentence, is_next)`.\n",
    "    \"\"\"\n",
    "    is_next = random.random() < 0.5\n",
    "    if not is_next:\n",
    "        # `paragraphs` is indexed as paragraph -> sentence\n",
    "        sampled_paragraph = random.choice(paragraphs)\n",
    "        next_sentence = random.choice(sampled_paragraph)\n",
    "    return sentence, next_sentence, is_next\n",
    "\n",
    "def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):\n",
    "    \"\"\"Create next-sentence-prediction examples from one paragraph.\n",
    "\n",
    "    Every pair of adjacent sentences in `paragraph` yields a candidate\n",
    "    example; candidates whose combined length plus 3 exceeds `max_len` are\n",
    "    dropped. `vocab` is accepted but not used here.\n",
    "\n",
    "    Returns:\n",
    "        A list of `(tokens, segments, is_next)` tuples.\n",
    "    \"\"\"\n",
    "    examples = []\n",
    "    for current, following in zip(paragraph, paragraph[1:]):\n",
    "        tokens_a, tokens_b, is_next = _get_next_sentence(\n",
    "            current, following, paragraphs)\n",
    "        # The extra 3 presumably covers one '<cls>' and two '<sep>' tokens\n",
    "        # added by d2l.get_tokens_and_segments -- confirm against d2l.\n",
    "        total_len = len(tokens_a) + len(tokens_b) + 3\n",
    "        if total_len <= max_len:\n",
    "            tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n",
    "            examples.append((tokens, segments, is_next))\n",
    "    return examples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "553e00b5",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Generating the Masked Language Modeling Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c71c28c0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:54.224660Z",
     "iopub.status.busy": "2023-08-18T19:27:54.224368Z",
     "iopub.status.idle": "2023-08-18T19:27:54.231144Z",
     "shell.execute_reply": "2023-08-18T19:27:54.230276Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,\n",
    "                        vocab):\n",
    "    \"\"\"Corrupt a token sequence for the masked language modeling task.\n",
    "\n",
    "    Args:\n",
    "        tokens: Sequence of tokens for one input; it is not modified.\n",
    "        candidate_pred_positions: Indices into `tokens` that may be selected\n",
    "            for prediction; the list is not modified.\n",
    "        num_mlm_preds: Maximum number of positions to select.\n",
    "        vocab: Object whose `idx_to_token` sequence supplies the random\n",
    "            replacement tokens.\n",
    "\n",
    "    Returns:\n",
    "        A pair `(mlm_input_tokens, pred_positions_and_labels)`: a copy of\n",
    "        `tokens` with the selected positions replaced, and a list of\n",
    "        `(position, original_token)` tuples in selection order.\n",
    "    \"\"\"\n",
    "    mlm_input_tokens = list(tokens)\n",
    "    pred_positions_and_labels = []\n",
    "    # Shuffle a private copy so that the caller's list is not reordered as a\n",
    "    # side effect; the random numbers consumed are the same as for an\n",
    "    # in-place shuffle\n",
    "    shuffled_positions = list(candidate_pred_positions)\n",
    "    random.shuffle(shuffled_positions)\n",
    "    for mlm_pred_position in shuffled_positions:\n",
    "        if len(pred_positions_and_labels) >= num_mlm_preds:\n",
    "            break\n",
    "        if random.random() < 0.8:\n",
    "            # 80% of the time: replace the token with '<mask>'\n",
    "            masked_token = '<mask>'\n",
    "        elif random.random() < 0.5:\n",
    "            # 10% of the time: keep the token unchanged\n",
    "            masked_token = tokens[mlm_pred_position]\n",
    "        else:\n",
    "            # 10% of the time: replace it with a random vocabulary token\n",
    "            masked_token = random.choice(vocab.idx_to_token)\n",
    "        mlm_input_tokens[mlm_pred_position] = masked_token\n",
    "        pred_positions_and_labels.append(\n",
    "            (mlm_pred_position, tokens[mlm_pred_position]))\n",
    "    return mlm_input_tokens, pred_positions_and_labels\n",
    "\n",
    "def _get_mlm_data_from_tokens(tokens, vocab):\n",
    "    \"\"\"Produce masked-language-modeling inputs and labels for one sequence.\n",
    "\n",
    "    Positions holding '<cls>' or '<sep>' are never selected. The number of\n",
    "    predictions is 15% of the sequence length, rounded, and at least 1.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(input_ids, pred_positions, label_ids)` where the ids come\n",
    "        from indexing `vocab` with token lists and `pred_positions` is in\n",
    "        ascending order.\n",
    "    \"\"\"\n",
    "    special_tokens = ('<cls>', '<sep>')\n",
    "    candidate_pred_positions = [\n",
    "        position for position, token in enumerate(tokens)\n",
    "        if token not in special_tokens]\n",
    "    num_mlm_preds = max(1, round(len(tokens) * 0.15))\n",
    "    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(\n",
    "        tokens, candidate_pred_positions, num_mlm_preds, vocab)\n",
    "    # Order the selected (position, label) pairs by position\n",
    "    pred_positions_and_labels.sort(key=lambda pair: pair[0])\n",
    "    pred_positions = [position for position, _ in pred_positions_and_labels]\n",
    "    mlm_pred_labels = [label for _, label in pred_positions_and_labels]\n",
    "    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2269ffa",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Append the special “&lt;pad&gt;” tokens to the inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9b2d2247",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:54.234922Z",
     "iopub.status.busy": "2023-08-18T19:27:54.234060Z",
     "iopub.status.idle": "2023-08-18T19:27:54.243205Z",
     "shell.execute_reply": "2023-08-18T19:27:54.242138Z"
    },
    "origin_pos": 15,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def _pad_bert_inputs(examples, max_len, vocab):\n",
    "    \"\"\"Pad every example to fixed lengths and convert its fields to tensors.\n",
    "\n",
    "    Token ids are padded with `vocab['<pad>']` and segments with 0 up to\n",
    "    `max_len`; prediction positions, weights and labels are padded with\n",
    "    zeros up to `round(max_len * 0.15)` entries.\n",
    "\n",
    "    Args:\n",
    "        examples: Iterable of `(token_ids, pred_positions,\n",
    "            mlm_pred_label_ids, segments, is_next)` tuples.\n",
    "        max_len: Target length of the token and segment sequences.\n",
    "        vocab: Mapping that is indexed with '<pad>' to get the padding id.\n",
    "\n",
    "    Returns:\n",
    "        Seven parallel lists of tensors: token ids, segments, valid lengths,\n",
    "        prediction positions, MLM weights, MLM labels and NSP labels.\n",
    "    \"\"\"\n",
    "    max_num_mlm_preds = round(max_len * 0.15)\n",
    "    all_token_ids, all_segments, valid_lens = [], [], []\n",
    "    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []\n",
    "    nsp_labels = []\n",
    "    for example in examples:\n",
    "        (token_ids, pred_positions, mlm_pred_label_ids, segments,\n",
    "         is_next) = example\n",
    "\n",
    "        padded_ids = token_ids + [vocab['<pad>']] * (max_len - len(token_ids))\n",
    "        all_token_ids.append(torch.tensor(padded_ids, dtype=torch.long))\n",
    "\n",
    "        padded_segments = segments + [0] * (max_len - len(segments))\n",
    "        all_segments.append(torch.tensor(padded_segments, dtype=torch.long))\n",
    "\n",
    "        # The valid length counts the tokens before padding\n",
    "        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))\n",
    "\n",
    "        num_missing_positions = max_num_mlm_preds - len(pred_positions)\n",
    "        padded_positions = pred_positions + [0] * num_missing_positions\n",
    "        all_pred_positions.append(\n",
    "            torch.tensor(padded_positions, dtype=torch.long))\n",
    "\n",
    "        # Weight 1.0 for real predictions, 0.0 for padded slots\n",
    "        weights = ([1.0] * len(mlm_pred_label_ids)\n",
    "                   + [0.0] * num_missing_positions)\n",
    "        all_mlm_weights.append(torch.tensor(weights, dtype=torch.float32))\n",
    "\n",
    "        padded_labels = mlm_pred_label_ids + [0] * (\n",
    "            max_num_mlm_preds - len(mlm_pred_label_ids))\n",
    "        all_mlm_labels.append(torch.tensor(padded_labels, dtype=torch.long))\n",
    "\n",
    "        nsp_labels.append(torch.tensor(is_next, dtype=torch.long))\n",
    "    return (all_token_ids, all_segments, valid_lens, all_pred_positions,\n",
    "            all_mlm_weights, all_mlm_labels, nsp_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66eebc05",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The WikiText-2 dataset for pretraining BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "27a00351",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:54.247275Z",
     "iopub.status.busy": "2023-08-18T19:27:54.246496Z",
     "iopub.status.idle": "2023-08-18T19:27:54.255307Z",
     "shell.execute_reply": "2023-08-18T19:27:54.254197Z"
    },
    "origin_pos": 18,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class _WikiTextDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Dataset of padded BERT pretraining examples built from paragraphs.\n",
    "\n",
    "    Each item is a 7-tuple: token ids, segments, valid length, prediction\n",
    "    positions, MLM weights, MLM labels and the NSP label.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, paragraphs, max_len):\n",
    "        \"\"\"Tokenize `paragraphs`, build the vocabulary and all examples.\n",
    "\n",
    "        Args:\n",
    "            paragraphs: List of paragraphs, each a list of sentence strings.\n",
    "            max_len: Length every token sequence is padded to.\n",
    "        \"\"\"\n",
    "        # paragraph -> sentence -> word tokens\n",
    "        tokenized = [d2l.tokenize(paragraph, token='word')\n",
    "                     for paragraph in paragraphs]\n",
    "        all_sentences = []\n",
    "        for paragraph in tokenized:\n",
    "            all_sentences.extend(paragraph)\n",
    "        self.vocab = d2l.Vocab(all_sentences, min_freq=5, reserved_tokens=[\n",
    "            '<pad>', '<mask>', '<cls>', '<sep>'])\n",
    "        # First collect the next sentence prediction data for every paragraph\n",
    "        nsp_examples = []\n",
    "        for paragraph in tokenized:\n",
    "            nsp_examples.extend(_get_nsp_data_from_paragraph(\n",
    "                paragraph, tokenized, self.vocab, max_len))\n",
    "        # Then attach the masked language modeling data to each example\n",
    "        examples = []\n",
    "        for tokens, segments, is_next in nsp_examples:\n",
    "            mlm_fields = _get_mlm_data_from_tokens(tokens, self.vocab)\n",
    "            examples.append(mlm_fields + (segments, is_next))\n",
    "        padded = _pad_bert_inputs(examples, max_len, self.vocab)\n",
    "        (self.all_token_ids, self.all_segments, self.valid_lens,\n",
    "         self.all_pred_positions, self.all_mlm_weights,\n",
    "         self.all_mlm_labels, self.nsp_labels) = padded\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the seven fields of the example at `idx` as a tuple.\"\"\"\n",
    "        fields = (self.all_token_ids, self.all_segments, self.valid_lens,\n",
    "                  self.all_pred_positions, self.all_mlm_weights,\n",
    "                  self.all_mlm_labels, self.nsp_labels)\n",
    "        return tuple(field[idx] for field in fields)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of examples.\"\"\"\n",
    "        return len(self.all_token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae671041",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Download the WikiText-2 dataset\n",
    "and generate pretraining examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6e8efaa8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:54.258912Z",
     "iopub.status.busy": "2023-08-18T19:27:54.258166Z",
     "iopub.status.idle": "2023-08-18T19:27:54.264300Z",
     "shell.execute_reply": "2023-08-18T19:27:54.263141Z"
    },
    "origin_pos": 21,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def load_data_wiki(batch_size, max_len):\n",
    "    \"\"\"Download WikiText-2 and wrap it in a shuffled pretraining DataLoader.\n",
    "\n",
    "    Args:\n",
    "        batch_size: Number of examples per minibatch.\n",
    "        max_len: Length every token sequence is padded to.\n",
    "\n",
    "    Returns:\n",
    "        A pair `(train_iter, vocab)`: the DataLoader over the training\n",
    "        examples and the vocabulary built by the dataset.\n",
    "    \"\"\"\n",
    "    num_workers = d2l.get_dataloader_workers()\n",
    "    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')\n",
    "    train_set = _WikiTextDataset(_read_wiki(data_dir), max_len)\n",
    "    train_iter = torch.utils.data.DataLoader(\n",
    "        train_set, batch_size=batch_size, shuffle=True,\n",
    "        num_workers=num_workers)\n",
    "    return train_iter, train_set.vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b44928",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Print out the shapes of a minibatch of BERT pretraining examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6b866b9c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:27:54.267688Z",
     "iopub.status.busy": "2023-08-18T19:27:54.267106Z",
     "iopub.status.idle": "2023-08-18T19:28:02.924683Z",
     "shell.execute_reply": "2023-08-18T19:28:02.923277Z"
    },
    "origin_pos": 23,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])\n"
     ]
    }
   ],
   "source": [
    "batch_size, max_len = 512, 64\n",
    "train_iter, vocab = load_data_wiki(batch_size, max_len)\n",
    "\n",
    "# Look at the first minibatch only and print the shape of each field\n",
    "for batch in train_iter:\n",
    "    (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,\n",
    "     mlm_Y, nsp_y) = batch\n",
    "    print(*(field.shape for field in batch))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "76cf9172",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:02.929142Z",
     "iopub.status.busy": "2023-08-18T19:28:02.928158Z",
     "iopub.status.idle": "2023-08-18T19:28:02.936806Z",
     "shell.execute_reply": "2023-08-18T19:28:02.935612Z"
    },
    "origin_pos": 25,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20256"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the vocabulary returned by load_data_wiki\n",
    "len(vocab)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "language_info": {
   "name": "python"
  },
  "required_libs": [],
  "rise": {
   "autolaunch": true,
   "enable_chalkboard": true,
   "overlay": "<div class='my-top-right'><img height=80px src='http://d2l.ai/_static/logo-with-text.png'/></div><div class='my-top-left'></div>",
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}